library(tidyverse)
library(ComplexHeatmap)
library(circlize)
library(pheatmap)
library(corrplot)
library(ggplot2)
library(reshape2)
library(grid)
library(dplyr)


# Prerequisite: `df_filtered_list` and `Legend_name` must already exist in the session; this script uses them but does not create them.

#------------------- start -----------------------

# Priority/display order of the Bf_* columns: it drives both the row sort and
# the heatmap column order further down.
col_order <- c(
  "Bf_anterior forebrain", "Bf_infundibular organ (IO)",
  "Bf_dorsal layer", "Bf_cells of Joseph",
  "Bf_dorsal ventricular cavity", "Bf_translumenal cells I (anterior)",
  "Bf_boundary cells I", "Bf_boundary cells II",
  "Bf_translumenal cells II (posterior-upper)",
  "Bf_translumenal cells III (posterior-lower)",
  "Bf_central canal", "Bf_nucleus of Rohde (nRo)",
  "Bf_migrated cells", "Bf_others"
)

# ---- Shared settings and helpers (defined once, outside the loop) ----------

# A cell value must exceed this to (a) place its row under that column in
# sort_matrix_columnwise() and (b) get a dot drawn on it in the heatmap.
value_threshold <- 0.2

# Fill colour scale: 25 breaks evenly spaced over [0, 1], "Oslo" HCL palette.
heatmap_col_fun <- colorRamp2(
  seq(0, 1, length.out = 25),
  hcl.colors(25, "Oslo")
)

# Colour for each Bf_* column label, used by the (optional) top annotation.
bf_type_colors <- c(
  "Bf_anterior forebrain" = "#CF8EB1",
  "Bf_boundary cells I" = "#CF8E4D",
  "Bf_boundary cells II" = "#A65628",
  "Bf_central canal" = "#549F6C",
  "Bf_dorsal layer" = "#FF7F00",
  "Bf_dorsal ventricular cavity" = "#F1E330",
  "Bf_infundibular organ (IO)" = "#E41A1C",
  "Bf_cells of Joseph" = "#C1862A",
  "Bf_translumenal cells I (anterior)" = "#FECB20",
  "Bf_translumenal cells II (posterior-upper)" = "#984EA3",
  "Bf_translumenal cells III (posterior-lower)" = "#728A7A",
  "Bf_migrated cells" = "#BFD1C4",
  "Bf_nucleus of Rohde (nRo)" = "#377EB8",
  "Bf_others" = "#9C3EFF"
)

#' Reorder rows so each row sits under the first column it exceeds a
#' threshold in.
#'
#' Columns are visited in `col_order`. At each column the rows not yet placed
#' are sorted by that column (decreasing); rows whose value is above
#' `threshold` are placed next and excluded from later passes. Rows never
#' above the threshold go last, in the order left by the final sort. Because
#' `order()` is stable, ties keep the order produced by earlier columns.
#'
#' @param df Data frame whose `col_order` columns are numeric.
#' @param threshold A row is placed at the first column where its value is
#'   strictly greater than this. `NA` never places a row.
#' @param col_order Column names in priority order; defaults to all columns.
#'   Names that are not columns of `df` are ignored.
#' @return `df` with its rows reordered; columns are unchanged and row names
#'   travel with their rows.
sort_matrix_columnwise <- function(df, threshold = 0.2, col_order = NULL) {
  if (is.null(col_order)) {
    col_order <- colnames(df)
  }
  # `df[[missing_name]]` is NULL and ordering by NULL would silently drop
  # every remaining row, so only visit columns that actually exist.
  col_order <- intersect(col_order, colnames(df))

  placed <- vector("list", length(col_order))
  remaining <- df
  for (k in seq_along(col_order)) {
    if (nrow(remaining) == 0) {
      break
    }
    col <- col_order[[k]]
    remaining <- remaining[order(remaining[[col]], decreasing = TRUE), ,
                           drop = FALSE]
    values <- remaining[[col]]
    # A bare `values > threshold` is NA for NA cells, which would create
    # all-NA rows when used as a row index.
    hit <- !is.na(values) & values > threshold
    placed[[k]] <- remaining[hit, , drop = FALSE]
    remaining <- remaining[!hit, , drop = FALSE]
  }

  # Bind once at the end; rbind() skips NULL slots left by an early break.
  do.call(rbind, c(placed, list(remaining)))
}

#' Build a heatmap cell_fun that draws a dot on every cell of `mat` whose
#' value is above `threshold` (NA cells are skipped).
#'
#' `mat` and `threshold` are captured when the function is created, so the
#' returned closure does not depend on later reassignment of globals.
make_dot_cell_fun <- function(mat, threshold) {
  force(mat)
  force(threshold)
  function(j, i, x, y, width, height, fill) {
    value <- mat[i, j]
    if (!is.na(value) && value > threshold) {
      grid.points(x, y, pch = 16, size = unit(1.5, "mm"),
                  gp = gpar(col = "#ff3"))
    }
  }
}

#' Save a heatmap to a PDF page sized to fit it.
#'
#' The heatmap is first laid out on a null PDF device purely to measure it
#' (nothing is written and no extra device is left open). It is then drawn
#' into `path` on a page `pad` times the measured size. Legends are not drawn.
#'
#' @param ht A ComplexHeatmap heatmap object.
#' @param path Output PDF file path.
#' @param pad Multiplier applied to the measured width and height.
#' @return `path`, invisibly.
export_heatmap_pdf <- function(ht, path, pad = 1.1) {
  pdf(NULL)
  size_inch <- tryCatch(
    {
      ht_drawn <- draw(ht,
                       show_heatmap_legend = FALSE,
                       show_annotation_legend = FALSE)
      # Convert while the measuring device is still open; grid may otherwise
      # open a default device just to perform the unit conversion.
      c(
        width = convertWidth(ComplexHeatmap:::width(ht_drawn), "inch",
                             valueOnly = TRUE),
        height = convertHeight(ComplexHeatmap:::height(ht_drawn), "inch",
                               valueOnly = TRUE)
      )
    },
    finally = dev.off()
  )

  pdf(path,
      width = size_inch[["width"]] * pad,
      height = size_inch[["height"]] * pad)
  # Close the PDF even if drawing fails.
  on.exit(dev.off(), add = TRUE)
  draw(ht, show_heatmap_legend = FALSE, show_annotation_legend = FALSE)
  invisible(path)
}

# ---- One heatmap PDF per table ---------------------------------------------

for (i in seq_along(df_filtered_list)) {
  # Group rows under their first above-threshold column (in col_order).
  df <- sort_matrix_columnwise(df_filtered_list[[i]],
                               threshold = value_threshold,
                               col_order = col_order)
  mat <- as.matrix(df)

  # col_order columns first, in that order, then any other columns, so the
  # column order always names every matrix column exactly once.
  column_order <- c(intersect(col_order, colnames(mat)),
                    setdiff(colnames(mat), col_order))

  # NOTE(review): built but not attached -- top_annotation is commented out
  # in the Heatmap() call below.
  top_anno <- HeatmapAnnotation(
    type = colnames(mat),
    col = list(type = bf_type_colors),
    annotation_name_gp = gpar(fontsize = 8),
    simple_anno_size = unit(2.5, "mm")
  )

  ht <- Heatmap(
    mat,
    # top_annotation = top_anno,
    col = heatmap_col_fun,
    name = Legend_name[[i]],
    column_names_rot = 45,
    column_order = column_order,
    column_names_side = "top",
    row_names_side = "right",
    # Rows are already in the desired order from sort_matrix_columnwise().
    row_order = seq_len(nrow(mat)),
    show_row_dend = TRUE,
    width = ncol(mat) * unit(8, "mm"),
    height = nrow(mat) * unit(4, "mm"),
    rect_gp = gpar(col = "#aaa", lwd = 1),
    row_names_gp = gpar(fontsize = 10),
    column_names_gp = gpar(fontsize = 10),
    show_heatmap_legend = TRUE,
    cell_fun = make_dot_cell_fun(mat, value_threshold),
    border = TRUE
  )

  export_heatmap_pdf(
    ht,
    paste0("./Mouse_SubClass vs Bf Xenium_Level1_", Legend_name[[i]], ".pdf")
  )
}


